{
  "cells": [
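    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# TensorIR Creation\n",
        "\n",
        "This section introduces how to write TensorIR functions in Apache TVM Unity. The tutorial assumes that you are already familiar with the basic concepts of TensorIR.\n",
        "\n",
        "```{note}\n",
        ":class: alert alert-info\n",
        "\n",
        "This tutorial focuses on building **standalone** TensorIR functions. The techniques presented here are not required for end users who only compile Relax models.\n",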
        "```"
      ]
    },
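    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Create TensorIR using TVMScript\n",
        "\n",
        "The most straightforward way to create a TensorIR function is to write it in TVMScript, a Python dialect that represents TensorIR in TVM.\n",
        "\n",
        "### Standard Format\n",
        "\n",
        "Below is the complete format of an `ir_module` containing one TensorIR function, written in TVMScript:\n"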
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false,
        "tags": [
          "hide-cell"
        ]
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tvm\n",
        "from tvm.script import ir as I\n",
        "from tvm.script import tir as T\n",
        "\n",
        "\n",
        "@I.ir_module\n",
        "class MyModule:\n",
        "    @T.prim_func\n",
        "    def mm_relu(\n",
        "        A: T.Buffer((128, 128), \"float32\"),\n",
        "        B: T.Buffer((128, 128), \"float32\"),\n",
        "        C: T.Buffer((128, 128), \"float32\"),\n",
        "    ):\n",
        "        Y = T.alloc_buffer((128, 128), dtype=\"float32\")\n",
        "        for i in range(128):\n",
        "            for j in range(128):\n",
        "                for k in range(128):\n",
        "                    with T.block(\"Y\"):\n",
        "                        vi = T.axis.spatial(128, i)\n",
        "                        vj = T.axis.spatial(128, j)\n",
        "                        vk = T.axis.reduce(128, k)\n",
        "                        T.reads(A[vi, vk], B[vk, vj])\n",
        "                        T.writes(Y[vi, vj])\n",
        "                        with T.init():\n",
        "                            Y[vi, vj] = T.float32(0)\n",
        "                        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]\n",
        "        for i in range(128):\n",
        "            for j in range(128):\n",
        "                with T.block(\"C\"):\n",
        "                    vi = T.axis.spatial(128, i)\n",
        "                    vj = T.axis.spatial(128, j)\n",
        "                    T.reads(Y[vi, vj])\n",
        "                    T.writes(C[vi, vj])\n",
        "                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))"
      ]
    },
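    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a reading aid, the two blocks of `mm_relu` can be mirrored in plain Python. This is only a sketch of the semantics under simplifying assumptions, not TVM code: the helper name `mm_relu_reference` and the small list-based inputs are made up for illustration.\n",
        "\n",
        "```python\n",
        "def mm_relu_reference(a, b):\n",
        "    # Pure-Python mirror of the 'Y' and 'C' blocks (illustrative only).\n",
        "    m, k_dim, n = len(a), len(b), len(b[0])\n",
        "    y = [[0.0] * n for _ in range(m)]\n",
        "    c = [[0.0] * n for _ in range(m)]\n",
        "    for i in range(m):\n",
        "        for j in range(n):\n",
        "            # Counterpart of T.init(): zero the accumulator before reducing over k\n",
        "            y[i][j] = 0.0\n",
        "            for k in range(k_dim):\n",
        "                y[i][j] = y[i][j] + a[i][k] * b[k][j]\n",
        "    for i in range(m):\n",
        "        for j in range(n):\n",
        "            # Counterpart of block 'C': element-wise ReLU\n",
        "            c[i][j] = max(y[i][j], 0.0)\n",
        "    return c\n",
        "\n",
        "\n",
        "a = [[1.0, -2.0], [3.0, 4.0]]\n",
        "b = [[1.0, 0.0], [0.0, 1.0]]\n",
        "print(mm_relu_reference(a, b))  # prints [[1.0, 0.0], [3.0, 4.0]]\n",
        "```\n",
        "\n",
        "The zero assignment plays the role of `T.init()`: it runs once per output element, before the reduction over `k` starts."
      ]
    },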
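    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Concise with Syntactic Sugar\n",
        "\n",
        "To make the code easier to write, the following syntactic sugar can be used:\n",
        "\n",
        "- use ``T.grid`` to condense nested loops;\n",
        "- use ``T.axis.remap`` to abbreviate block iterator annotations;\n",
        "- omit ``T.reads`` and ``T.writes`` for blocks whose access regions can be inferred from the block body."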
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false,
        "tags": [
          "hide-cell"
        ]
      },
      "outputs": [],
      "source": [
        "@I.ir_module\n",
        "class ConciseModule:\n",
        "    @T.prim_func\n",
        "    def mm_relu(\n",
        "        A: T.Buffer((128, 128), \"float32\"),\n",
        "        B: T.Buffer((128, 128), \"float32\"),\n",
        "        C: T.Buffer((128, 128), \"float32\"),\n",
        "    ):\n",
        "        Y = T.alloc_buffer((128, 128), dtype=\"float32\")\n",
        "        for i, j, k in T.grid(128, 128, 128):\n",
        "            with T.block(\"Y\"):\n",
        "                vi, vj, vk = T.axis.remap(\"SSR\", [i, j, k])\n",
        "                with T.init():\n",
        "                    Y[vi, vj] = T.float32(0)\n",
        "                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]\n",
        "        for i, j in T.grid(128, 128):\n",
        "            with T.block(\"C\"):\n",
        "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
        "                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The following code verifies that the two modules are structurally equivalent:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "True\n"
          ]
        }
      ],
      "source": [
        "print(tvm.ir.structural_equal(MyModule, ConciseModule))"
      ]
    },
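    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop-condensing sugar can be pictured with the Python standard library: `T.grid(a, b, c)` stands for three nested loops, much as `itertools.product` flattens nested `range` loops. The snippet below is a plain-Python analogy rather than TVMScript, and the names `nested` and `flat` are illustrative.\n",
        "\n",
        "```python\n",
        "from itertools import product\n",
        "\n",
        "# Index order produced by three explicitly nested loops\n",
        "nested = []\n",
        "for i in range(2):\n",
        "    for j in range(3):\n",
        "        for k in range(2):\n",
        "            nested.append((i, j, k))\n",
        "\n",
        "# The same index order from a single flattened loop\n",
        "flat = [(i, j, k) for i, j, k in product(range(2), range(3), range(2))]\n",
        "\n",
        "print(nested == flat)  # prints True\n",
        "```"
      ]
    },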
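    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Interacting with Python Variables\n",
        "\n",
        "Although TVMScript is not executed by the Python interpreter, limited interaction with Python is possible. For example, Python variables can be used to specify the shapes and data types in a TensorIR function."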
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false,
        "tags": [
          "hide-cell"
        ]
      },
      "outputs": [],
      "source": [
        "# Python variables\n",
        "M = N = K = 128\n",
        "dtype = \"float32\"\n",
        "\n",
        "\n",
        "# IRModule in TVMScript\n",
        "@I.ir_module\n",
        "class ConciseModuleFromPython:\n",
        "    @T.prim_func\n",
        "    def mm_relu(\n",
        "        A: T.Buffer((M, K), dtype),\n",
        "        B: T.Buffer((K, N), dtype),\n",
        "        C: T.Buffer((M, N), dtype),\n",
        "    ):\n",
        "        Y = T.alloc_buffer((M, N), dtype)\n",
        "        for i, j, k in T.grid(M, N, K):\n",
        "            with T.block(\"Y\"):\n",
        "                vi, vj, vk = T.axis.remap(\"SSR\", [i, j, k])\n",
        "                with T.init():\n",
        "                    Y[vi, vj] = T.cast(T.float32(0), dtype)\n",
        "                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]\n",
        "        for i, j in T.grid(M, N):\n",
        "            with T.block(\"C\"):\n",
        "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
        "                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Check the equivalence:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "True\n"
          ]
        }
      ],
      "source": [
        "print(tvm.ir.structural_equal(ConciseModule, ConciseModuleFromPython))"
      ]
    },
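    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### TensorIR Function with Dynamic Shapes\n",
        "\n",
        "A TensorIR function can also work with shapes that are only known at runtime. In the example below, the dimensions are declared as symbolic variables with `T.int32()`, the arguments are taken as `T.handle`, and each handle is bound to a buffer of symbolic shape with `T.match_buffer`. The `dtype` used here is still the Python variable defined above."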
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false,
        "tags": [
          "hide-cell"
        ]
      },
      "outputs": [],
      "source": [
        "@I.ir_module\n",
        "class DynamicShapeModule:\n",
        "    @T.prim_func\n",
        "    def mm_relu(a: T.handle, b: T.handle, c: T.handle):\n",
        "        # Dynamic shape definition\n",
        "        M, N, K = T.int32(), T.int32(), T.int32()\n",
        "\n",
        "        # Bind the input buffers with the dynamic shapes\n",
        "        A = T.match_buffer(a, [M, K], dtype)\n",
        "        B = T.match_buffer(b, [K, N], dtype)\n",
        "        C = T.match_buffer(c, [M, N], dtype)\n",
        "        Y = T.alloc_buffer((M, N), dtype)\n",
        "        for i, j, k in T.grid(M, N, K):\n",
        "            with T.block(\"Y\"):\n",
        "                vi, vj, vk = T.axis.remap(\"SSR\", [i, j, k])\n",
        "                with T.init():\n",
        "                    Y[vi, vj] = T.cast(T.float32(0), dtype)\n",
        "                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]\n",
        "        for i, j in T.grid(M, N):\n",
        "            with T.block(\"C\"):\n",
        "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
        "                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now let's check that the dynamic shapes are resolved at runtime:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "tags": [
          "hide-output"
        ]
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[[0.8489785  0.9980212  1.4608577  1.2539588 ]\n",
            " [0.87849003 1.4613248  1.5408815  1.1791621 ]\n",
            " [1.1160902  0.8524983  0.8697982  0.9582412 ]\n",
            " [0.80623764 1.3069451  1.3202583  0.99451154]]\n",
            "[[35.642204 33.659966 34.592144 ... 34.283146 34.511303 33.300488]\n",
            " [33.99455  29.586454 28.126196 ... 32.934216 28.9799   31.524603]\n",
            " [32.142868 30.852024 32.131252 ... 32.32186  29.78737  30.77026 ]\n",
            " ...\n",
            " [35.054104 33.21723  30.768831 ... 30.570036 31.797892 33.43776 ]\n",
            " [35.745453 29.975988 32.201622 ... 33.83027  30.723768 31.028706]\n",
            " [34.69189  33.023502 31.475616 ... 32.04393  31.498629 31.942545]]\n"
          ]
        }
      ],
      "source": [
        "def evaluate_dynamic_shape(lib: tvm.runtime.Module, m: int, n: int, k: int):\n",
        "    A = tvm.runtime.tensor(np.random.uniform(size=(m, k)).astype(\"float32\"))\n",
        "    B = tvm.runtime.tensor(np.random.uniform(size=(k, n)).astype(\"float32\"))\n",
        "    C = tvm.runtime.tensor(np.zeros((m, n), dtype=\"float32\"))\n",
        "    lib(A, B, C)\n",
        "    return C.numpy()\n",
        "\n",
        "\n",
        "# Compile lib only once\n",
        "dyn_shape_lib = tvm.build(DynamicShapeModule, target=\"llvm\")\n",
        "# Able to handle different shapes\n",
        "print(evaluate_dynamic_shape(dyn_shape_lib, m=4, n=4, k=4))\n",
        "print(evaluate_dynamic_shape(dyn_shape_lib, m=64, n=64, k=128))"
      ]
    },
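    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, the symbolic variables `M`, `N` and `K` take their values from the shapes of the actual arguments on each call, which is why a single compiled library can serve several shapes. A rough plain-Python sketch of that binding step follows. `bind_shapes` is a hypothetical helper, not a TVM API, and its consistency checks only mirror the idea that every use of a symbol has to agree.\n",
        "\n",
        "```python\n",
        "def bind_shapes(a_shape, b_shape, c_shape):\n",
        "    # Hypothetical helper: resolve the symbols M, K, N from concrete argument shapes.\n",
        "    m, k = a_shape      # A is declared as [M, K]\n",
        "    k_b, n = b_shape    # B is declared as [K, N]\n",
        "    if k_b != k:\n",
        "        raise ValueError(f'K mismatch: A has {k}, B has {k_b}')\n",
        "    if tuple(c_shape) != (m, n):\n",
        "        raise ValueError(f'C must have shape {(m, n)}, got {tuple(c_shape)}')\n",
        "    return {'M': m, 'N': n, 'K': k}\n",
        "\n",
        "\n",
        "# The two shape combinations evaluated above\n",
        "print(bind_shapes((4, 4), (4, 4), (4, 4)))\n",
        "print(bind_shapes((64, 128), (128, 64), (64, 64)))\n",
        "```"
      ]
    },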
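    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Create TensorIR using Tensor Expression\n",
        "\n",
        "Often the details of TensorIR are not of interest, and it is more convenient to describe the computation concisely and have the TensorIR generated from that description. This is where Tensor Expression (TE) comes in.\n",
        "\n",
        "Tensor Expression (TE) is a domain-specific language that describes a sequence of computations through an expression-like API.\n",
        "\n",
        "```{note}\n",
        "Tensor Expression consists of two components in the TVM stack: the expression and the schedule. The expression is the domain-specific language that captures the computation pattern, which is what this section discusses. The TE schedule, by contrast, is the legacy scheduling method and has been superseded by the TensorIR schedule in the TVM Unity stack.\n",
        "```\n",
        "\n",
        "### Create Static-Shape Functions\n",
        "\n",
        "The ``mm_relu`` example from the previous subsection is used to demonstrate the TE creation method."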
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from tvm import te\n",
        "\n",
        "A = te.placeholder((128, 128), \"float32\", name=\"A\")\n",
        "B = te.placeholder((128, 128), \"float32\", name=\"B\")\n",
        "k = te.reduce_axis((0, 128), \"k\")\n",
        "Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name=\"Y\")\n",
        "C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name=\"C\")"
      ]
    },
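    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here ``te.compute`` takes the signature ``te.compute(output_shape, fcompute)``.\n",
        "The ``fcompute`` function describes how to compute the value of each element ``Y[i, j]`` for a given index.\n",
        "\n",
        "The lambda expression above encapsulates the computation $Y_{i, j} = \\sum_k A_{i, k} \\times B_{k, j}$. With the computation defined, a TensorIR function is created by listing the tensors that serve as its parameters.\n",
        "In this case, the goal is a function with two input parameters **A, B** and one output parameter **C**."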
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "collapsed": false,
        "tags": [
          "hide-output"
        ]
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">mm_relu</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        Y <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>alloc_buffer((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>))\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j, k <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;Y&quot;</span>):\n",
              "                v_i, v_j, v_k <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SSR&quot;</span>, [i, j, k])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[v_i, v_k], B[v_k, v_j])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(Y[v_i, v_j])\n",
              "                <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>init():\n",
              "                    Y[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>)\n",
              "                Y[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">=</span> Y[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">+</span> A[v_i, v_k] <span style=\"color: #AA22FF; font-weight: bold\">*</span> B[v_k, v_j]\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;C&quot;</span>):\n",
              "                v_i, v_j <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(Y[v_i, v_j])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(C[v_i, v_j])\n",
              "                C[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>max(Y[v_i, v_j], T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>))\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "te_func = te.create_prim_func([A, B, C]).with_attr({\"global_symbol\": \"mm_relu\"})\n",
        "TEModule = tvm.IRModule({\"mm_relu\": te_func})\n",
        "TEModule.show()"
      ]
    },
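    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To picture what ``te.compute(output_shape, fcompute)`` describes, the sketch below evaluates an `fcompute` callback once per output index in plain Python. The `compute` helper is a hypothetical stand-in written for this illustration. Unlike TE, which records the expression symbolically, it evaluates eagerly and stores the results in a dictionary keyed by index.\n",
        "\n",
        "```python\n",
        "from itertools import product\n",
        "\n",
        "\n",
        "def compute(shape, fcompute):\n",
        "    # Hypothetical stand-in: call fcompute once per output index.\n",
        "    return {idx: fcompute(*idx) for idx in product(*(range(s) for s in shape))}\n",
        "\n",
        "\n",
        "a = [[1.0, 2.0], [3.0, 4.0]]\n",
        "b = [[0.0, 1.0], [1.0, 0.0]]\n",
        "k_extent = 2\n",
        "\n",
        "# y[i, j] = sum_k a[i][k] * b[k][j], then c[i, j] = max(y[i, j], 0)\n",
        "y = compute((2, 2), lambda i, j: sum(a[i][k] * b[k][j] for k in range(k_extent)))\n",
        "c = compute((2, 2), lambda i, j: max(y[i, j], 0.0))\n",
        "\n",
        "print(c[0, 0], c[0, 1], c[1, 0], c[1, 1])  # prints 2.0 1.0 4.0 3.0\n",
        "```"
      ]
    },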
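    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Create Dynamic-Shape Functions\n",
        "\n",
        "Tensor Expression can also create dynamic-shape functions. The only difference is that the shapes of the input tensors are specified with symbolic variables."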
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "collapsed": false,
        "tags": [
          "hide-output"
        ]
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">mm_relu</span>(var_A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle, var_B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle, var_C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        m, n <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32(), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        A <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(var_A, (m, n))\n",
              "        k <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
              "        B <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(var_B, (k, n))\n",
              "        C <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(var_C, (m, n))\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        Y <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>alloc_buffer((m, n))\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j, k_1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(m, n, k):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;Y&quot;</span>):\n",
              "                v_i, v_j, v_k <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SSR&quot;</span>, [i, j, k_1])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[v_i, v_k], B[v_k, v_j])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(Y[v_i, v_j])\n",
              "                <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>init():\n",
              "                    Y[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>)\n",
              "                Y[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">=</span> Y[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">+</span> A[v_i, v_k] <span style=\"color: #AA22FF; font-weight: bold\">*</span> B[v_k, v_j]\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(m, n):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;C&quot;</span>):\n",
              "                v_i, v_j <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(Y[v_i, v_j])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(C[v_i, v_j])\n",
              "                C[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>max(Y[v_i, v_j], T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>))\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Declare symbolic variables\n",
        "M, N, K = te.var(\"m\"), te.var(\"n\"), te.var(\"k\")\n",
        "A = te.placeholder((M, N), \"float32\", name=\"A\")\n",
        "B = te.placeholder((K, N), \"float32\", name=\"B\")\n",
        "k = te.reduce_axis((0, K), \"k\")\n",
        "Y = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name=\"Y\")\n",
        "C = te.compute((M, N), lambda i, j: te.max(Y[i, j], 0), name=\"C\")\n",
        "\n",
        "dyn_te_func = te.create_prim_func([A, B, C]).with_attr({\"global_symbol\": \"mm_relu\"})\n",
        "DynamicTEModule = tvm.IRModule({\"mm_relu\": dyn_te_func})\n",
        "DynamicTEModule.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "ai",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
